import torch
import torch.nn as nn
import torch.optim as optim

class aa_bp(nn.Module):
    """Single-hidden-layer MLP: a Linear -> Sigmoid feature extractor plus a linear classifier.

    The hidden (latent) width is ``n_features * 3 // 2``.

    Args:
        n_features: size of the last input dimension fed to the extractor.
        num_classes: number of output logits produced by ``classifier``.
        random_state: accepted but currently unused.
    """

    def __init__(self, n_features=243, num_classes=18, random_state=None):
        super().__init__()
        self._n_features = n_features
        self._n_outputs = num_classes
        self._latent = self._n_features * 3 // 2
        # Width of the features returned by ``extract`` / ``forward(only_feat=True)``.
        self.num_features = self._latent
        self._model = nn.Sequential(
            nn.Linear(self._n_features, self._latent),
            nn.Sigmoid(),
        )
        self.classifier = nn.Linear(self._latent, self._n_outputs)

    @staticmethod
    def _squeeze_keep_batch(x):
        """Drop size-1 dimensions after dim 0, never reducing ``x`` below 2-D.

        ``torch.squeeze(x)`` without a ``dim`` would also remove a batch
        dimension of size 1, turning a single-sample batch into an
        unbatched tensor. Walking from the last dimension backwards keeps
        earlier indices valid as later dimensions are removed.
        """
        for dim in range(x.dim() - 1, 0, -1):
            if x.dim() <= 2:
                break
            if x.size(dim) == 1:
                x = x.squeeze(dim)
        return x

    def forward(self, x, only_fc=False, only_feat=False, **kwargs):
        """Run the network.

        Args:
            x: input tensor. Singleton dimensions after the first are dropped
                (down to 2-D), so e.g. ``(B, 1, F)`` becomes ``(B, F)``; the
                leading dimension is kept even when it is 1. The last
                dimension must be the latent width when ``only_fc`` is set,
                otherwise ``n_features``.
            only_fc: apply only the classifier; ``x`` is treated as features.
            only_feat: return only the extracted (post-sigmoid) features.
            **kwargs: ignored.

        Returns:
            The logits tensor if ``only_fc``; the feature tensor if
            ``only_feat``; otherwise ``{'logits': logits, 'feat': features}``.
        """
        x = self._squeeze_keep_batch(x)
        if only_fc:
            return self.classifier(x)

        feat = self.extract(x)
        if only_feat:
            return feat

        logits = self.classifier(feat)
        return {'logits': logits, 'feat': feat}

    def extract(self, x):
        """Map inputs to latent features via Linear -> Sigmoid."""
        return self._model(x)

    def group_matcher(self, coarse=False, prefix=''):
        """Return an empty parameter-group matcher; ``coarse`` and ``prefix`` are unused."""
        return {}

    def no_weight_decay(self):
        """Return names of parameters to exclude from weight decay.

        A parameter is excluded when its name contains ``'bn'`` or ``'bias'``.
        """
        return [
            name
            for name, _ in self.named_parameters()
            if 'bn' in name or 'bias' in name
        ]

def AA_BP(pretrained=False, pretrained_path=None, **kwargs):
    """Build an ``aa_bp`` model.

    Args:
        pretrained: loading pretrained weights is not implemented; a
            ``UserWarning`` is emitted if this is truthy.
        pretrained_path: unused.
        **kwargs: ``n_features`` (default 24) and ``num_classes``
            (default 18) are forwarded to ``aa_bp``; other keys are ignored.

    Returns:
        A freshly initialised ``aa_bp`` instance.
    """
    n_features = kwargs.get('n_features', 24)
    num_classes = kwargs.get('num_classes', 18)
    if pretrained:
        import warnings
        # Previously the request was dropped silently; make that visible.
        warnings.warn(
            'AA_BP: pretrained weights are not supported; '
            'returning a randomly initialised model.',
            UserWarning,
        )
    return aa_bp(n_features=n_features, num_classes=num_classes)